{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Import custom class and functions\n",
    "from inputproducer import InputProducer\n",
    "#from tracker import TrackerVanilla\n",
    "from vgg16 import Vgg16\n",
    "from selcnn import SelCNN\n",
    "#from sgnet import GNet, SNet\n",
    "from utils import img_with_bbox, IOU_eval\n",
    "\n",
    "import numpy as np \n",
    "import tensorflow as tf\n",
    "\n",
    "import os\n",
    "\n",
    "import skimage.io\n",
    "show = skimage.io.imshow\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "\n",
    "# Hyper-parameters exposed as TensorFlow flags.\n",
    "# Python concatenates adjacent string literals verbatim, so a fragment that\n",
    "# continues a sentence has to carry its own separating space.\n",
    "tf.app.flags.DEFINE_integer('iter_step_sel', 200,\n",
    "                            \"\"\"Number of steps for trainning \"\"\"\n",
    "                            \"\"\"selCNN networks.\"\"\")\n",
    "tf.app.flags.DEFINE_integer('iter_step_sg', 50,\n",
    "                            \"\"\"Number of steps for trainning \"\"\"\n",
    "                            \"\"\"SGnet works\"\"\")\n",
    "tf.app.flags.DEFINE_integer('num_sel', 384,\n",
    "                            \"\"\"Number of feature maps selected.\"\"\")\n",
    "tf.app.flags.DEFINE_integer('iter_max', 200,\n",
    "                            \"\"\"Max iter times through imgs\"\"\")\n",
    "\n",
    "FLAGS = tf.app.flags.FLAGS\n",
    "\n",
    "# Input locations, given as relative paths.\n",
    "DATA_ROOT = 'data/Dog1'\n",
    "IMG_PATH = os.path.join(DATA_ROOT, 'img')\n",
    "GT_PATH = os.path.join(DATA_ROOT, 'groundtruth_rect.txt')\n",
    "VGG_WEIGHTS_PATH = 'vgg16_weights.npz'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/xlws/repos/FCNT_TF/inputproducer.py:244: VisibleDeprecationWarning: using a non-integer number instead of an integer will result in an error in the future\n",
      "  roi =  imresize(img[y1-1:y2, x1-1:x2, :], [self.roi_params['roi_size'], self.roi_params['roi_size']], interp='bicubic')\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0013186425201 max of mask\n",
      "(224, 224, 3)\n",
      "250 max convas\n",
      "0.0013186425201 max of mask\n",
      "(224, 224, 3)\n",
      "228 max convas\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/xlws/repos/FCNT_TF/inputproducer.py:85: VisibleDeprecationWarning: using a non-integer number instead of an integer will result in an error in the future\n",
      "  roi = convas[cy-half:cy+half, cx-half:cx+half, :]\n"
     ]
    }
   ],
   "source": [
    "# Take the first item from the producer's generator.\n",
    "# presumably (frame, ground-truth box, frame index) -- TODO confirm against InputProducer\n",
    "inputProducer = InputProducer(IMG_PATH, GT_PATH)\n",
    "img, gt, t  = next(inputProducer.gen_img)\n",
    "# Only the first of the four returned values is kept; `roi` is not used by any\n",
    "# other visible cell.\n",
    "roi,_,_,_ = inputProducer.extract_roi_deprecated(img, gt)\n",
    "\n",
    "# Single session shared by the cells below.\n",
    "sess = tf.Session()\n",
    "\n",
    "# Vgg16 receives the weights file path and the session.\n",
    "vgg = Vgg16(VGG_WEIGHTS_PATH, sess)\n",
    "\n",
    "# Selection networks attached to VGG conv4_3 ('sel_local') and conv5_3 ('sel_global').\n",
    "lselCNN = SelCNN('sel_local', vgg.conv4_3)\n",
    "gselCNN = SelCNN('sel_global', vgg.conv5_3)\n",
    "\n",
    "# Ground-truth masks generated at each selection network's pre_M_size.\n",
    "lgt_M = inputProducer.gen_mask(lselCNN.pre_M_size)\n",
    "ggt_M = inputProducer.gen_mask(gselCNN.pre_M_size)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7f69d774c5c0>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAARwAAAEZCAYAAABFOZpTAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAADPhJREFUeJzt3U2MVeUdx/HfD0dUmAQHIjPK+FJSWxemITZ1Qxc29oV0\ng3HRGjfqwriorUutG7bqgpSNG0VDmxpjTSy0SSttWBjTqERLRUUlMVhenAEJKGjQUf5d3EOZTu/w\nHO4987+Xe7+fZMKdM8+983CAL+ece+YcR4QAIMOiXk8AwPAgOADSEBwAaQgOgDQEB0AaggMgzUg3\nT7a9TtJv1ArX5oh4tM0Y3ncHhkxEuN1yd3oeju1Fkt6XdKukQ5J2SrojIt6dM47gAENmvuB0s0t1\ns6S9EfFhRMxIelbS+i5eD8CA6yY4qyTtn/X5gWoZALTFQWMAaboJzkFJ18z6fLJaBgBtdROcnZK+\nafta24sl3SFpWzPTAjCIOn5bPCK+tn2/pO06+7b4nsZmBmDgdPy2eO1vwNviwNBZiLfFAeC8EBwA\naQgOgDQEB0AaggMgDcEBkIbgAEhDcACkITgA0hAcAGkIDoA0BAdAGoIDIA3BAZCG4ABIQ3AApCE4\nANIQHABpCA6ANAQHQBqCAyANwQGQhuAASENwAKQhOADSEBwAaQgOgDQEB0AaggMgDcEBkIbgAEhD\ncACkITgA0hAcAGkIDoA0BAdAmpFunmx7n6RPJJ2WNBMRNzcxKQCDqavgqBWaWyLiWBOTATDYut2l\ncgOvAWBIdBuLkPQ32ztt39vEhAAMrm53qdZGxEe2r1ArPHsi4uUmJgZg8HS1hRMRH1W/HpH0giQO\nGgOYV8fBsb3E9mj1eKmkH0t6q6mJARg83exSjUt6wXZUr/P7iNjezLQADCJHxMJ+g1aQAAyRiHC7\n5bylDSANwQGQhuAASENwAKQhOADSEBwAaQgOgDQEB0AaggMgDcEBkIbgAEhDcACkITgA0hAcAGkI\nDoA0BAdAGoIDIA3BAZCG4ABIQ3AApCE4ANIQHABpCA6ANAQHQBqCAyANwQGQhuAASDPS6wlAWrSo\n3P2LLrqoOGZkpPzHefHFFzfyOnXmXMfXX39dHPPVV18Vx8zMzKR9r4gojkF7bOEASENwAKQhOADS\nEBwAaQgOgDQEB0AaggMgDcEBkIYT//rA0qVLi2OWL19eHDM+Pl4cs2rVquKYiYmJ4phly5YVx9Q5\nQe7YsWPFMQcPHmxkzJEjR4pjjh49Whxz6tSp4hi0V9zCsb3Z9rTtN2ctG7O93fZ7tl+0Xf7bB2Do\n1dmlelrST+Yse0jS3yPi25J2SPp10xMDMHiKwYmIlyXN3e5dL2lL9XiLpNsanheAAdTpQeOVETEt\nSRExJWllc1MCMKiaepeKH58FUNRpcKZtj0uS7QlJh5ubEoBBVTc4rj7O2Cbp7urxXZK2NjgnAAOq\nztviz0j6h6Rv2f637XskPSLpR7bfk3Rr9TkAnFPxxL+IuHOeL/2w4bkMrcWLFxfHjI2NFcdcddVV\nxTE33HBDccz1119fHHPFFVcUx5w+fbo45tChQ8UxS5YsKY6pczW/Oifsffrpp428DtrjRxsApCE4\nANIQHABpCA6ANAQHQBqCAyANwQGQhuAASMMV//pAndvvjo6OFsfUueLf6tWri2NuvPHG4pjJycni\nmDq3zb388suLY7744ovimDpX6vv444+LY+rcUhmdYwsHQBqCAyANwQGQhuAASENwAKQhOADSEBwA\naQgOgDSc+NcH6pxsdskllxTH1Dk5cMWKFcUxdW71e+WVVxbH1Dnx7/PPPy+OqXNyYJ2rAtZZh4sW\n8X/wQmLtAkhDcACkITgA0hAcAGkIDoA0BAdAGoIDIA3BAZCGE//6QEQUx9S5bW6dE+2+/PLL4pg6\nV9irM2ZmZibtderc6rfOmDp/FugcWz
gA0hAcAGkIDoA0BAdAGoIDIA3BAZCG4ABIQ3AApOHEvz7Q\n1JXxjh07Vhxz6NCh4pixsbHimM8++6w4ps6Jdvv37y+OmZqaKo45fvx4ccypU6eKY+qcYInOFbdw\nbG+2PW37zVnLNtg+YPuN6mPdwk4TwCCos0v1tKSftFm+MSJuqj7+2vC8AAygYnAi4mVJ7bbV3fx0\nAAyybg4a3297l+0nbS9rbEYABlanwXlc0uqIWCNpStLG5qYEYFB1FJyIOBJnf47/CUnfa25KAAZV\n3eBYs47Z2J59p7TbJb3V5KQADKbieTi2n5F0i6QVtv8taYOkH9heI+m0pH2S7lvAOQIYEMXgRMSd\nbRY/vQBzGVp1rmj3ySefFMfUOUFu7969jcxn+fLlxTF1TqI7fPhwccwHH3zQyOucOHGiOKbOSZjo\nHD/aACANwQGQhuAASENwAKQhOADSEBwAaQgOgDQEB0AaL/StTW1z79SCSy+9tDhmdHS0OGbZsvIP\n7de5ml+dMZdddllxTJ2/WydPniyOqXMlwzpj6pz4V2dMnRMjh11EtL18DVs4ANIQHABpCA6ANAQH\nQBqCAyANwQGQhuAASENwAKThxD8AjePEPwA9R3AApCE4ANIQHABpCA6ANAQHQBqCAyANwQGQhuAA\nSENwAKQhOADSEBwAaQgOgDQEB0AaggMgDcEBkIbgAEhDcACkKQbH9qTtHbbftr3b9q+q5WO2t9t+\nz/aLtss3tgYw1IrXNLY9IWkiInbZHpX0uqT1ku6RdDQiHrP9oKSxiHiozfO5pjEwZDq+pnFETEXE\nrurxSUl7JE2qFZ0t1bAtkm5rZqoABtV5HcOxfZ2kNZJekTQeEdNSK0qSVjY9OQCDpXZwqt2p5yU9\nUG3pzN1VYtcJwDnVCo7tEbVi87uI2FotnrY9Xn19QtLhhZkigEFRdwvnKUnvRMSmWcu2Sbq7enyX\npK1znwQAs9V5l2qtpJck7VZrtykkPSzpNUnPSbpa0oeSfhYRx9s8n10tYMjM9y4Vt/oF0Dhu9Qug\n5wgOgDQEB0AaggMgDcEBkIbgAEhDcACkITgA0hAcAGkIDoA0BAdAGoIDIA3BAZCG4ABIQ3AApCE4\nANIQHABpCA6ANAQHQBqCAyANwQGQhuAASENwAKQhOADSEBwAaQgOgDQEB0AaggMgDcEBkIbgAEhD\ncACkITgA0hAcAGkIDoA0BAdAGoIDIE0xOLYnbe+w/bbt3bZ/WS3fYPuA7Teqj3ULP10AFzJHxLkH\n2BOSJiJil+1RSa9LWi/p55JORMTGwvPP/Q0ADJyIcLvlIzWeOCVpqnp80vYeSauqL7d9UQBo57yO\n4di+TtIaSa9Wi+63vcv2k7aXNTw3AAOmdnCq3annJT0QESclPS5pdUSsUWsL6Jy7VgBQPIYjSbZH\nJP1Z0l8iYlObr18r6U8R8Z02X+MYDjBk5juGU3cL5ylJ78yOTXUw+YzbJb3V+fQADIM671KtlfSS\npN2Sovp4WNKdah3POS1pn6T7ImK6zfPZwgGGzHxbOLV2qbpBcIDh0+0uFQB0jeAASENwAKQhOADS\nEBwAaQgOgDQEB0AaggMgDcEBkIbgAEhDcACkITgA0hAcAGkIDoA0BAdAGoIDIA3BAZCG4ABIQ3AA\npFnwaxoDwBls4QBIQ3AApCE4ANKkBsf2Otvv2n7f9oOZ37tTtvfZ/pftf9p+rdfzmY/tzbanbb85\na9mY7e2237P9ou1lvZzjXPPMeYPtA7bfqD7W9XKOs9metL3D9tu2d9v+VbW839fz3Hn/slqevq7T\nDhrbXiTpfUm3SjokaaekOyLi3ZQJdMj2B5K+GxHHej2Xc7H9fUknJf32zD3ebT8q6WhEPFYFfiwi\nHurlPGebZ84bJJ2IiI09nVwb1e2tJyJil+1RSa9LWi/pHvX3ep5v3j9X8rrO3MK5WdLeiPgwImYk\nPa
vWb7rfWRfArmdEvCxpbhTXS9pSPd4i6bbUSRXMM2eptc77TkRMRcSu6vFJSXskTar/13O7ea+q\nvpy6rjP/Ia2StH/W5wd09jfdz0LS32zvtH1vrydznlaeud97RExJWtnj+dR1v+1dtp/st92TM2xf\nJ2mNpFckjV8o63nWvF+tFqWu677/n7sPrI2ImyT9VNIvqt2AC9WFcNLV45JWR8QaSVOS+nHXalTS\n85IeqLYY5q7XvlzPbeadvq4zg3NQ0jWzPp+slvW1iPio+vWIpBfU2jW8UEzbHpf+ux9/uMfzKYqI\nI3H2wOITkr7Xy/nMZXtErX+0v4uIrdXivl/P7ebdi3WdGZydkr5p+1rbiyXdIWlb4vc/b7aXVP8r\nyPZSST+W9FZvZ3VO1v/uk2+TdHf1+C5JW+c+oQ/8z5yrf7Bn3K7+W99PSXonIjbNWnYhrOf/m3cv\n1nXqjzZUb7ttUit0myPikbRv3gHb31BrqyYkjUj6fb/O2fYzkm6RtELStKQNkv4o6Q+Srpb0oaSf\nRcTxXs1xrnnm/AO1jjGclrRP0n1njo/0mu21kl6StFutvxMh6WFJr0l6Tv27nueb951KXtf8LBWA\nNBw0BpCG4ABIQ3AApCE4ANIQHABpCA6ANAQHQJr/AGnTk5nIUIB0AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7f6a2445fc88>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show(lgt_M)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Cast each 2-D mask to float32 and add batch and channel axes -> (1, H, W, 1).\n",
    "# The ndim guard keeps the cell idempotent: running it a second time no longer\n",
    "# stacks further axes onto masks that were already expanded.\n",
    "if lgt_M.ndim == 2:\n",
    "    lgt_M = lgt_M.astype(np.float32)[np.newaxis, :, :, np.newaxis]\n",
    "if ggt_M.ndim == 2:\n",
    "    ggt_M = ggt_M.astype(np.float32)[np.newaxis, :, :, np.newaxis]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "\n",
    "import tensorflow as tf\n",
    "import numpy as np \n",
    "\n",
    "from utils import variable_on_cpu, variable_with_weight_decay\n",
    "\n",
    "\n",
    "class SGNet:\n",
    "    \"\"\"Base class for the S/G heat-map networks; defines the shared structure.\n",
    "\n",
    "    Two convolution + ReLU layers are stacked on a float32 placeholder, and\n",
    "    `loss` adds a mean-squared-error term to a weight-decay term.\n",
    "\n",
    "    Attributes:\n",
    "        scope: Name given to the constructor; used for the TF scopes.\n",
    "        params: Hyper-parameter dict with keys 'num_fms' and 'wd'.\n",
    "        input_maps: Placeholder of shape `vgg_conv_shape`.\n",
    "        pre_M: Output tensor of the second layer (the predicted heat map).\n",
    "        variables: Kernels and biases of both layers, in creation order.\n",
    "        kernel_weights: Kernels only; the tensors that `loss` regularizes.\n",
    "    \"\"\"\n",
    "\n",
    "    # Define class level optimizer\n",
    "    #lr = 1e-6\n",
    "    #optimizer = tf.train.GradientDescentOptimizer(lr)\n",
    "\n",
    "    def __init__(self, scope, vgg_conv_shape):\n",
    "        \"\"\"Build the network graph under `scope`.\n",
    "\n",
    "        Args:\n",
    "            scope: String name of the enclosing TF variable scope.\n",
    "            vgg_conv_shape: Full shape of the input feature maps; its last\n",
    "                entry is used as the number of input channels.\n",
    "        \"\"\"\n",
    "        self.scope = scope\n",
    "        self.params = {\n",
    "            # number of selected feature maps, inputs of the network\n",
    "            # (only referenced by a commented-out assertion at the moment)\n",
    "            'num_fms': 200,\n",
    "            # L2 regularization coefficient\n",
    "            'wd': 0.5,\n",
    "        }\n",
    "        self.variables = []\n",
    "        self.kernel_weights = []\n",
    "        # The context manager is deliberately not bound with `as`, so the\n",
    "        # `scope` argument is not shadowed by the scope object.\n",
    "        with tf.variable_scope(scope):\n",
    "            self.pre_M = self._build_graph(vgg_conv_shape)\n",
    "\n",
    "    def _conv_relu(self, inputs, kernel_shape, name):\n",
    "        \"\"\"Add one stride-1 SAME conv + bias + ReLU layer and record its variables.\n",
    "\n",
    "        Args:\n",
    "            inputs: Input tensor of the layer.\n",
    "            kernel_shape: [height, width, in_channels, out_channels].\n",
    "            name: Name scope that groups the ops of the layer.\n",
    "\n",
    "        Returns:\n",
    "            The ReLU activation tensor of the layer.\n",
    "        \"\"\"\n",
    "        with tf.name_scope(name) as layer_scope:\n",
    "            kernel = tf.Variable(tf.truncated_normal(kernel_shape, dtype=tf.float32,\n",
    "                                                     stddev=1e-1), name='weights')\n",
    "            conv = tf.nn.conv2d(inputs, kernel, [1, 1, 1, 1], padding='SAME')\n",
    "            # One bias per output channel, starting at zero.\n",
    "            biases = tf.Variable(tf.constant(0.0, shape=[kernel_shape[-1]], dtype=tf.float32),\n",
    "                                 trainable=True, name='biases')\n",
    "            out = tf.nn.bias_add(conv, biases)\n",
    "            activation = tf.nn.relu(out, name=layer_scope)\n",
    "        self.variables += [kernel, biases]\n",
    "        self.kernel_weights += [kernel]\n",
    "        return activation\n",
    "\n",
    "    def _build_graph(self, vgg_conv_shape):\n",
    "        \"\"\"Define the network structure.\n",
    "\n",
    "        The first additional convolutional layer has kernels of size 9x9 and\n",
    "        outputs 36 feature maps as the input to the next layer. The second\n",
    "        additional convolutional layer has kernels of size 5x5 and outputs\n",
    "        the foreground heat map of the input image. ReLU is chosen as the\n",
    "        nonlinearity for these two layers.\n",
    "\n",
    "        Args:\n",
    "            vgg_conv_shape: Shape of the input placeholder; the last entry is\n",
    "                the number of input channels of the first kernel.\n",
    "\n",
    "        Returns:\n",
    "            conv2: Output tensor of the second layer, with one channel.\n",
    "        \"\"\"\n",
    "        self.variables = []\n",
    "        self.kernel_weights = []\n",
    "        in_channels = vgg_conv_shape[-1]\n",
    "        self.input_maps = tf.placeholder(tf.float32, shape=vgg_conv_shape,\n",
    "                                         name='selected_maps')\n",
    "        #assert vgg_conv_shape[-1] == self.params['num_fms']\n",
    "\n",
    "        conv1 = self._conv_relu(self.input_maps, [9, 9, in_channels, 36], 'conv1')\n",
    "        print(conv1.get_shape().as_list(), 'conv1 shape')\n",
    "\n",
    "        conv2 = self._conv_relu(conv1, [5, 5, 36, 1], 'conv2')\n",
    "        # Bias-add and ReLU keep the static shape, so this prints the same\n",
    "        # shape as the raw convolution output.\n",
    "        print(conv2.get_shape().as_list(), 'conv shape')\n",
    "\n",
    "        print('Shape of the out put heat map for %s is %s'%(self.scope, conv2.get_shape().as_list()))\n",
    "        return conv2\n",
    "\n",
    "    def loss(self, gt_M):\n",
    "        \"\"\"Build the loss tensor for the current network.\n",
    "\n",
    "        The loss is the mean squared difference between `gt_M` and\n",
    "        `self.pre_M`, plus `params['wd']` times the sum over all kernels of\n",
    "        their mean squared weight.\n",
    "\n",
    "        Args:\n",
    "            gt_M: Ground truth heat map; anything `tf.sub` can combine with\n",
    "                `self.pre_M`.\n",
    "\n",
    "        Returns:\n",
    "            Tensor holding the data term plus the weight-decay term.\n",
    "        \"\"\"\n",
    "        with tf.name_scope(self.scope):\n",
    "            beta = tf.constant(self.params['wd'], name='beta')\n",
    "            loss_rms = tf.reduce_mean(tf.square(tf.sub(gt_M, self.pre_M)))\n",
    "            decay_terms = [tf.reduce_mean(tf.square(w)) for w in self.kernel_weights]\n",
    "            loss_wd = beta * tf.add_n(decay_terms)\n",
    "            total_loss = loss_rms + loss_wd\n",
    "        return total_loss\n",
    "\n",
    "    @classmethod\n",
    "    def eadge_RP(cls):\n",
    "        \"\"\"Placeholder for edge region proposals (not implemented yet).\n",
    "\n",
    "        Intended to propose a series of ROIs along the edges of a given\n",
    "        frame. It should be called when the particle confidence falls below\n",
    "        a critical value, which possibly accounts for object re-appearance.\n",
    "\n",
    "        A classmethod always receives the class as its first argument, so\n",
    "        the `cls` parameter is required for the method to be callable.\n",
    "        \"\"\"\n",
    "        pass\n",
    "\n",
    "\n",
    "\n",
    "class GNet(SGNet):\n",
    "    \"\"\"SGNet variant meant to keep its parameters fixed after the first frame.\n",
    "\n",
    "    Nothing in this class enforces the freezing; it only reuses the parent graph.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, scope, vgg_conv_shape):\n",
    "        \"\"\"Delegate all graph construction to SGNet.\"\"\"\n",
    "        super().__init__(scope, vgg_conv_shape)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "class SNet(SGNet):\n",
    "    \"\"\"SGNet variant that is initialized in the first frame.\"\"\"\n",
    "\n",
    "    def __init__(self, scope, vgg_conv_shape):\n",
    "        \"\"\"Delegate all graph construction to SGNet.\"\"\"\n",
    "        super().__init__(scope, vgg_conv_shape)\n",
    "\n",
    "    def adaptive_finetune(self, sess, updated_gt_M):\n",
    "        \"\"\"Finetune SNet with updated_gt_M (stub: currently does nothing).\"\"\"\n",
    "        return None\n",
    "\n",
    "    def descrimtive_finetune(self, sess, init_gt_M, cur_):\n",
    "        \"\"\"Stub with no implementation yet; always returns None.\"\"\"\n",
    "        return None\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Random arrays used as network inputs; the last axis has FLAGS.num_sel channels.\n",
    "# presumably stand-ins for the selected VGG feature maps -- TODO confirm\n",
    "# NOTE(review): no random seed is set, so the values differ between runs.\n",
    "s_sel_maps = np.random.random((1, 28,28,FLAGS.num_sel))\n",
    "g_sel_maps = np.random.random((1, 14,14,FLAGS.num_sel))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "[n.name for n in tf.get_default_graph().as_graph_def().node]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1, 14, 14, 36] conv1 shape\n",
      "[1, 14, 14, 1] conv shape\n",
      "Shape of the out put heat map for GNet is [1, 14, 14, 1]\n",
      "[1, 28, 28, 36] conv1 shape\n",
      "[1, 28, 28, 1] conv shape\n",
      "Shape of the out put heat map for SNet is [1, 28, 28, 1]\n"
     ]
    }
   ],
   "source": [
    "# Instantiate the G and S networks. Only the shapes of the selected maps are\n",
    "# passed here; the values are fed later through each network's input_maps placeholder.\n",
    "gnet = GNet('GNet', g_sel_maps.shape)\n",
    "snet = SNet('SNet', s_sel_maps.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "[v.get_shape().as_list() for v in gnet.kernel_weights]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "\n",
    "# Runs the initializer of every variable in the graph, not only the SGNet ones.\n",
    "# NOTE(review): if Vgg16 keeps its loaded weights in tf.Variables this would reset\n",
    "# them -- confirm against vgg16.py.\n",
    "sess.run(tf.initialize_all_variables())\n",
    "#sess.run(tf.report_uninitialized_variables())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "#uninit_SGNet_vars = gnet.variables + snet.variables\n",
    "#init_SGNet_vars_op = tf.initialize_variables(uninit_SGNet_vars)\n",
    "#sess.run(init_SGNet_vars_op)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TensorShape([Dimension(1), Dimension(14), Dimension(14), Dimension(1)])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gnet.pre_M.get_shape()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor 'Sub:0' shape=(1, 28, 28, 1) dtype=float32>"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.sub(snet.pre_M, lgt_M)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Loss tensor of SNet against lgt_M, the mask generated at lselCNN.pre_M_size.\n",
    "sloss = snet.loss(lgt_M)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Loss tensor of GNet against ggt_M, the mask generated at gselCNN.pre_M_size.\n",
    "gloss = gnet.loss(ggt_M)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Joint objective: the sum of both networks' losses.\n",
    "total_losses = sloss + gloss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Kernels and biases of both networks, later passed as var_list to the optimizer.\n",
    "sgNet_vars = gnet.variables + snet.variables"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Plain gradient descent whose updates are restricted to the SGNet variables.\n",
    "# NOTE(review): train_op is only built here; no visible cell passes it to sess.run.\n",
    "lr = 1e-8\n",
    "optimizer = tf.train.GradientDescentOptimizer(lr)\n",
    "\n",
    "train_op = optimizer.minimize(total_losses, var_list= sgNet_vars)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "\n",
    "feed_dict = {gnet.input_maps: g_sel_maps, snet.input_maps: s_sel_maps}\n",
    "loss = sess.run(total_losses, feed_dict = feed_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "211.17242"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "snet.input_maps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "gnet.input_maps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Train sgnet by minimizing the loss\n",
    "#   Loss = Lg + Ls\n",
    "# where Li = |pre_Mi - gt_M|**2 + Weights_decay_term_i\n",
    "#\n",
    "# NOTE: train_op is built below, but this cell only evaluates the loss once;\n",
    "# it does not run a training step.\n",
    "\n",
    "lr = 1e-8\n",
    "optimizer = tf.train.GradientDescentOptimizer(lr)\n",
    "\n",
    "train_op = optimizer.minimize(total_losses, var_list= sgNet_vars)\n",
    "\n",
    "# The SGNet variables are assumed to be initialized by an earlier cell.\n",
    "\n",
    "# Session.run takes the placeholder values through the `feed_dict` keyword;\n",
    "# any other keyword name raises TypeError.\n",
    "feed_dic = {gnet.input_maps: g_sel_maps, snet.input_maps: s_sel_maps}\n",
    "loss = sess.run(total_losses, feed_dict=feed_dic)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# NOTE(review): gt_M is not assigned in any visible cell, so this raises NameError\n",
    "# on a fresh kernel. The float32 cast is already applied to lgt_M / ggt_M in an\n",
    "# earlier cell -- confirm whether this cell is a leftover.\n",
    "gt_M = gt_M.astype(np.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [Root]",
   "language": "python",
   "name": "Python [Root]"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
